Chapter 4: Geocentric Models¶
[1]:
# Auto-format code cells with black on execution.
%load_ext jupyter_black
[2]:
import jax
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
import numpyro.distributions as dist
import pandas as pd
import plotly
import plotly.graph_objects as go
import plotly.io as pio
from scipy import stats, optimize
from scipy.interpolate import BSpline
# Render pandas plots with plotly by default.
pd.options.plotting.backend = "plotly"
seed = 84735
pio.templates.default = "plotly_white"
# Two generators: NumPy rng for scipy/pandas sampling, JAX key for numpyro.
rng = np.random.default_rng(seed=seed)
jrng = jax.random.key(seed)
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Code¶
Code 4.1¶
[3]:
# 1,000 random walks of 16 uniform steps in [-1, 1] (numpyro version).
steps = dist.Uniform(low=-1, high=1).sample(jrng, sample_shape=(1_000, 16))
[4]:
# Same simulation via scipy. Uniform(loc, scale) draws loc + scale * U(0, 1),
# so loc=-1, scale=2 reproduces -1 + 2 * U exactly (overwrites the jax draws).
steps = stats.uniform.rvs(loc=-1, scale=2, size=(1_000, 16))
[5]:
# Sum of 16 steps is approximately Gaussian (central limit theorem in action).
pd.DataFrame(steps.sum(axis=1)).plot(kind="hist")
Code 4.2¶
[6]:
# Product of 12 small multiplicative growth rates.
# NOTE(review): `jrng` is reused unsplit across cells, so draws are correlated
# between cells — consider jax.random.split before each sample.
steps = dist.Uniform(low=0, high=0.1).sample(jrng, sample_shape=(12,))
jnp.prod(1 + steps)
[6]:
Array(1.8501737, dtype=float32)
Code 4.3¶
[7]:
# Products of many small multiplicative effects are approximately normal.
growth = jnp.prod(
    1 + dist.Uniform(low=0, high=0.1).sample(jrng, sample_shape=(10_000, 12)), axis=1
)
az.plot_density({"growth": growth}, hdi_prob=1)
[7]:
array([[<Axes: title={'center': 'growth'}>]], dtype=object)
Code 4.4¶
[8]:
# Products of large effects are skewed (log-normal), small effects near-normal.
# Fix: previously both draws used the same key `jrng`, so `big` and `small`
# were perfectly correlated (each small step was the big step scaled down).
# Derive two fresh subkeys instead; `jrng` itself is left untouched so later
# cells behave as before.
key_big, key_small = jax.random.split(jax.random.fold_in(jrng, 44))
big = jnp.prod(
    1 + dist.Uniform(low=0, high=0.5).sample(key_big, sample_shape=(10_000, 12)),
    axis=1,
)
small = jnp.prod(
    1 + dist.Uniform(low=0, high=0.1).sample(key_small, sample_shape=(10_000, 12)),
    axis=1,
)
az.plot_density({"big": big, "small": small}, hdi_prob=1)
[8]:
array([[<Axes: title={'center': 'big'}>,
<Axes: title={'center': 'small'}>]], dtype=object)
Code 4.5¶
[9]:
# On the log scale the product of uniforms is approximately Gaussian.
log_big = jnp.log(
    jnp.prod(
        1 + dist.Uniform(low=0, high=0.5).sample(jrng, sample_shape=(10_000, 12)),
        axis=1,
    )
)
ax = az.plot_density({"log_big": log_big}, hdi_prob=1)
# Overlay a normal density with matched mean/std for visual comparison.
x = jnp.sort(log_big)
gaussian = jnp.exp(dist.Normal(jnp.mean(x), jnp.std(x)).log_prob(x))
ax[0][0].plot(x, gaussian, "--")
[9]:
[<matplotlib.lines.Line2D at 0x7d87aef7a720>]
Code 4.6¶
[10]:
# Grid-approximate posterior for the globe-tossing model: 6 waters in 9 tosses.
w = 6
n = 9
p_grid = jnp.linspace(0, 1, 100)
# Unnormalized posterior: binomial likelihood times a flat Uniform(0, 1) prior.
posterior = jnp.exp(dist.Binomial(total_count=n, probs=p_grid).log_prob(w)) * jnp.exp(
    dist.Uniform(low=0, high=1).log_prob(p_grid)
)
posterior /= posterior.sum()
pd.DataFrame(posterior, index=p_grid).plot()
Code 4.7¶
[11]:
# Howell census data; the file is semicolon-separated.
df = pd.read_csv("../data/Howell1.csv", sep=";")
Code 4.8¶
[12]:
# Inspect the data (rich display; df.head() would keep the output smaller).
df
[12]:
| height | weight | age | male | |
|---|---|---|---|---|
| 0 | 151.765 | 47.825606 | 63.0 | 1 |
| 1 | 139.700 | 36.485807 | 63.0 | 0 |
| 2 | 136.525 | 31.864838 | 65.0 | 0 |
| 3 | 156.845 | 53.041914 | 41.0 | 1 |
| 4 | 145.415 | 41.276872 | 51.0 | 0 |
| ... | ... | ... | ... | ... |
| 539 | 145.415 | 31.127751 | 17.0 | 1 |
| 540 | 162.560 | 52.163080 | 31.0 | 1 |
| 541 | 156.210 | 54.062497 | 21.0 | 0 |
| 542 | 71.120 | 8.051258 | 0.0 | 1 |
| 543 | 158.750 | 52.531624 | 68.0 | 1 |
544 rows × 4 columns
Code 4.9¶
[13]:
# Per-column summary statistics.
df.describe()
[13]:
| height | weight | age | male | |
|---|---|---|---|---|
| count | 544.000000 | 544.000000 | 544.000000 | 544.000000 |
| mean | 138.263596 | 35.610618 | 29.344393 | 0.472426 |
| std | 27.602448 | 14.719178 | 20.746888 | 0.499699 |
| min | 53.975000 | 4.252425 | 0.000000 | 0.000000 |
| 25% | 125.095000 | 22.007717 | 12.000000 | 0.000000 |
| 50% | 148.590000 | 40.057844 | 27.000000 | 0.000000 |
| 75% | 157.480000 | 47.209005 | 43.000000 | 1.000000 |
| max | 179.070000 | 62.992589 | 88.000000 | 1.000000 |
Code 4.10¶
[14]:
df["height"]
[14]:
0 151.765
1 139.700
2 136.525
3 156.845
4 145.415
...
539 145.415
540 162.560
541 156.210
542 71.120
543 158.750
Name: height, Length: 544, dtype: float64
Code 4.11¶
[15]:
df2 = df[df["age"] >= 18]
Code 4.12¶
[16]:
# Visualize the prior for mu: Normal(178, 20).
x = jnp.linspace(100, 250)
pd.DataFrame(stats.norm.pdf(x, loc=178, scale=20), index=x).plot()
Code 4.13¶
[17]:
# Visualize the prior for sigma: Uniform(0, 50).
x = jnp.linspace(-10, 60)
pd.DataFrame(stats.uniform.pdf(x, loc=0, scale=50), index=x).plot()
Code 4.14¶
[18]:
# Prior predictive simulation for the height model, sampled by hand.
_, jrng = jax.random.split(jrng)
sample_mu = dist.Normal(loc=178, scale=20).sample(jrng, (10_000,))
_, jrng = jax.random.split(jrng)
sample_sigma = dist.Uniform(low=0, high=50).sample(jrng, (10_000,))
_, jrng = jax.random.split(jrng)
# One simulated height per (mu, sigma) prior draw.
prior_predictive = dist.Normal(loc=sample_mu, scale=sample_sigma).sample(jrng)
az.plot_density({"Prior Predictive Distribution": prior_predictive}, hdi_prob=1)
[18]:
array([[<Axes: title={'center': 'Prior Predictive Distribution'}>]],
dtype=object)
[19]:
def adult_height_model(priors, heights):
    """Gaussian height model: mu ~ Normal, sigma ~ Uniform, height ~ Normal(mu, sigma).

    priors: dict with keys mu_mean, mu_scale, sigma_low, sigma_high.
    heights: observed heights, or None to sample from the prior predictive.
    """
    mu = numpyro.sample(
        "mu", dist.Normal(loc=priors["mu_mean"], scale=priors["mu_scale"])
    )
    sigma = numpyro.sample(
        "sigma", dist.Uniform(low=priors["sigma_low"], high=priors["sigma_high"])
    )
    numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=heights)
# Same prior predictive as above, now via numpyro's Predictive utility.
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=10_000)(
    jrng,
    priors={"mu_mean": 178, "mu_scale": 20, "sigma_low": 0, "sigma_high": 50},
    heights=None,
)
az.plot_density(prior_samples, hdi_prob=1)
[19]:
array([[<Axes: title={'center': 'height'}>,
<Axes: title={'center': 'mu'}>,
<Axes: title={'center': 'sigma'}>]], dtype=object)
Code 4.15¶
[20]:
# Prior predictive with a much flatter mu prior (scale 100): produces absurd
# heights, illustrating why the tighter prior is preferable.
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=10_000)(
    jrng,
    priors={"mu_mean": 178, "mu_scale": 100, "sigma_low": 0, "sigma_high": 50},
    heights=None,
)
az.plot_density(prior_samples, hdi_prob=1)
[20]:
array([[<Axes: title={'center': 'height'}>,
<Axes: title={'center': 'mu'}>,
<Axes: title={'center': 'sigma'}>]], dtype=object)
Code 4.16¶
[21]:
# Grid approximation of the joint posterior over (mu, sigma).
mu_list = jnp.linspace(start=150, stop=160, num=100)
sigma_list = jnp.linspace(start=7, stop=9, num=100)
mesh = jnp.meshgrid(mu_list, sigma_list)
posterior = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
# Log-likelihood of all adult heights at every grid point.
posterior["LL"] = jax.vmap(
    lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(df2.height.values))
)(posterior["mu"], posterior["sigma"])
logprob_mu = dist.Normal(178, 20).log_prob(posterior["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(posterior["sigma"])
posterior["prob"] = posterior["LL"] + logprob_mu + logprob_sigma
# Subtract the max before exponentiating for numerical stability.
posterior["prob"] = jnp.exp(posterior["prob"] - jnp.max(posterior["prob"]))
Code 4.17¶
[22]:
# Contour plot of the grid posterior (reshaped back to the 100x100 grid).
plt.contour(
    posterior["mu"].reshape(100, 100),
    posterior["sigma"].reshape(100, 100),
    posterior["prob"].reshape(100, 100),
)
plt.show()
Code 4.18¶
[23]:
# Same posterior as a heatmap; extent matches the (mu, sigma) grid ranges.
plt.imshow(
    posterior["prob"].reshape(100, 100),
    origin="lower",
    extent=(150, 160, 7, 9),
    aspect="auto",
)
plt.show()
Code 4.19¶
[24]:
# Sample (mu, sigma) pairs from the grid, proportional to posterior probability.
prob = posterior["prob"] / jnp.sum(posterior["prob"])
# NOTE(review): `jrng` is reused here without splitting — consider a fresh subkey.
sample_rows = dist.Categorical(probs=prob).sample(jrng, (int(1e4),))
sample_mu = posterior["mu"][sample_rows]
sample_sigma = posterior["sigma"][sample_rows]
[25]:
pd.DataFrame({"mu": sample_mu, "sigma": sample_sigma}).plot(
kind="scatter", x="mu", y="sigma", backend="matplotlib", alpha=0.1
)
[25]:
<Axes: xlabel='mu', ylabel='sigma'>
Code 4.20¶
[26]:
# Marginal posterior of mu.
az.plot_kde(sample_mu)
[26]:
<Axes: >
[27]:
# Marginal posterior of sigma.
az.plot_kde(sample_sigma)
[27]:
<Axes: >
Code 4.22¶
[28]:
print(f"mu 89% HPDI: {numpyro.diagnostics.hpdi(sample_mu, prob=0.89)}")
print(f"sigma 89% HPDI: {numpyro.diagnostics.hpdi(sample_sigma, prob=0.89)}")
mu 89% HPDI: [154.0404 155.35353]
sigma 89% HPDI: [7.2828283 8.212121 ]
Code 4.23¶
[29]:
df3 = df2["height"].sample(n=20, random_state=seed)
Code 4.24¶
[30]:
# Repeat the grid approximation on the 20-point subsample (wider grid).
# NOTE(review): near-duplicate of the earlier grid cell — a shared helper
# function would remove the repetition.
mu_list = jnp.linspace(start=100, stop=170, num=200)
sigma_list = jnp.linspace(start=4, stop=20, num=200)
mesh = jnp.meshgrid(mu_list, sigma_list)
posterior2 = {"mu": mesh[0].reshape(-1), "sigma": mesh[1].reshape(-1)}
posterior2["LL"] = jax.vmap(
    lambda mu, sigma: jnp.sum(dist.Normal(mu, sigma).log_prob(df3.values))
)(posterior2["mu"], posterior2["sigma"])
logprob_mu = dist.Normal(178, 20).log_prob(posterior2["mu"])
logprob_sigma = dist.Uniform(0, 50).log_prob(posterior2["sigma"])
posterior2["prob"] = posterior2["LL"] + logprob_mu + logprob_sigma
# Stabilize before exponentiating, then sample grid points by probability.
posterior2["prob"] = jnp.exp(posterior2["prob"] - jnp.max(posterior2["prob"]))
prob = posterior2["prob"] / jnp.sum(posterior2["prob"])
sample2_rows = dist.Categorical(probs=prob).sample(jrng, (int(1e4),))
sample2_mu = posterior2["mu"][sample2_rows]
sample2_sigma = posterior2["sigma"][sample2_rows]
plt.scatter(sample2_mu, sample2_sigma, s=64, alpha=0.1, edgecolor="none")
plt.show()
Code 4.25¶
[31]:
# Marginal of mu with a matched-normal overlay for comparison.
az.plot_kde(sample2_mu)
x = jnp.sort(sample2_mu)
plt.plot(x, jnp.exp(dist.Normal(jnp.mean(x), jnp.std(x)).log_prob(x)), "--")
plt.show()
[32]:
# Marginal of sigma with a matched-normal overlay.
az.plot_kde(sample2_sigma)
x = jnp.sort(sample2_sigma)
plt.plot(x, jnp.exp(dist.Normal(jnp.mean(x), jnp.std(x)).log_prob(x)), "--")
plt.show()
Code 4.26¶
[33]:
# Reload the data fresh for the quadratic-approximation section.
Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
df = Howell1
df2 = df[df["age"] >= 18]
Code 4.27¶
[34]:
def adult_height_model(height, priors):
    """Gaussian model of adult height.

    height: observed heights (array), or None to sample from the prior.
    priors: dict with keys mu_mean, mu_scale, sigma_low, sigma_high.
    """
    # Name the prior distributions first, then register the sample sites.
    mu_prior = dist.Normal(loc=priors["mu_mean"], scale=priors["mu_scale"])
    sigma_prior = dist.Uniform(low=priors["sigma_low"], high=priors["sigma_high"])
    mu = numpyro.sample("mu", mu_prior)
    sigma = numpyro.sample("sigma", sigma_prior)
    numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
Code 4.28¶
[35]:
# Fit with SVI and a Laplace guide (quadratic approximation of the posterior).
adult_height_laplace_model = AutoLaplaceApproximation(adult_height_model)
adult_height_svi = SVI(
    model=adult_height_model,
    guide=adult_height_laplace_model,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    height=df2.height.values,
    priors={"mu_mean": 178, "mu_scale": 20, "sigma_low": 0, "sigma_high": 50},
).run(jrng, 5_000)
# NOTE(review): auto_loc appears to be in the guide's *unconstrained* space
# (sigma is transformed, hence the negative second entry) — confirm against
# numpyro's AutoLaplaceApproximation docs; sample_posterior maps back.
adult_height_svi.params
100%|██████████████████| 5000/5000 [00:01<00:00, 3658.83it/s, init loss: 28358.6328, avg. loss [4751-5000]: 1226.0383]
[35]:
{'auto_loc': Array([154.60709 , -1.6973906], dtype=float32)}
Code 4.29¶
[36]:
# Draw from the Laplace approximation and summarize the marginals.
_, _jrng = jax.random.split(jrng)
samples = adult_height_laplace_model.sample_posterior(
    _jrng, adult_height_svi.params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(samples, 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat
mu 154.61 0.40 154.60 153.98 155.25 977.98 1.00
sigma 7.74 0.30 7.73 7.31 8.23 1044.95 1.00
Code 4.30¶
[37]:
# Same fit, but initialize the guide at the sample mean/std and use a smaller
# learning rate; the result matches the previous fit.
init_fn_values = {"mu": df2["height"].mean(), "sigma": df2["height"].std()}
adult_height_laplace_model = AutoLaplaceApproximation(
    adult_height_model, init_loc_fn=numpyro.infer.init_to_value(values=init_fn_values)
)
adult_height_svi = SVI(
    model=adult_height_model,
    guide=adult_height_laplace_model,
    optim=numpyro.optim.Adam(step_size=0.01),
    loss=Trace_ELBO(),
    height=df2.height.values,
    priors={"mu_mean": 178, "mu_scale": 20, "sigma_low": 0, "sigma_high": 50},
).run(jrng, 5000)
adult_height_svi.params
100%|███████████████████| 5000/5000 [00:01<00:00, 3099.28it/s, init loss: 1226.0387, avg. loss [4751-5000]: 1226.0383]
[37]:
{'auto_loc': Array([154.60728 , -1.6970557], dtype=float32)}
[38]:
# Summarize the re-initialized fit.
_, _jrng = jax.random.split(_jrng)
samples = adult_height_laplace_model.sample_posterior(
    _jrng, adult_height_svi.params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(samples, 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat
mu 154.59 0.40 154.59 153.99 155.26 965.98 1.00
sigma 7.76 0.29 7.76 7.28 8.20 892.76 1.00
Code 4.31¶
[39]:
# Refit with a very narrow mu prior (scale 0.1): the posterior mean barely
# moves from the prior, and sigma inflates to absorb the misfit.
adult_height_laplace_model_2 = AutoLaplaceApproximation(adult_height_model)
adult_height_svi_2 = SVI(
    model=adult_height_model,
    guide=adult_height_laplace_model_2,
    optim=numpyro.optim.Adam(1),
    loss=Trace_ELBO(),
    height=df2.height.values,
    priors={"mu_mean": 178, "mu_scale": 0.1, "sigma_low": 0, "sigma_high": 50},
).run(jrng, 2000)
adult_height_svi_2.params
100%|████████████████| 2000/2000 [00:00<00:00, 2592.33it/s, init loss: 1619491.5000, avg. loss [1901-2000]: 1626.5827]
[39]:
{'auto_loc': Array([ 1.7786377e+02, -3.8493071e-02], dtype=float32)}
[40]:
# Summarize the narrow-prior fit.
_, _jrng = jax.random.split(_jrng)
samples = adult_height_laplace_model_2.sample_posterior(
    _jrng, adult_height_svi_2.params, sample_shape=(1000,)
)
numpyro.diagnostics.print_summary(samples, 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat
mu 177.86 0.10 177.86 177.69 178.02 909.28 1.00
sigma 24.51 0.94 24.48 23.16 26.11 945.10 1.00
Code 4.32¶
[41]:
# Variance-covariance matrix of (mu, sigma) estimated from posterior draws.
_, _jrng = jax.random.split(_jrng)
samples = adult_height_laplace_model.sample_posterior(
    _jrng, adult_height_svi.params, sample_shape=(1000,)
)
# NOTE(review): relies on the sample dict iterating as (mu, sigma) — confirm.
vcov = pd.DataFrame(jnp.stack(list(samples.values())), index=["mu", "sigma"]).T.cov()
vcov
[41]:
| mu | sigma | |
|---|---|---|
| mu | 0.162994 | -0.002029 |
| sigma | -0.002029 | 0.086252 |
Code 4.33¶
[42]:
# Standard deviations and the correlation matrix implied by vcov.
variances = jnp.diagonal(vcov.values)
print(variances ** 0.5)
print(vcov.values / jnp.sqrt(jnp.outer(variances, variances)))
[0.40372518 0.2936876 ]
[[ 1. -0.0171095]
[-0.0171095 1. ]]
Code 4.34¶
[43]:
# Larger posterior sample as a DataFrame for tabular summaries.
_, _jrng = jax.random.split(_jrng)
samples = pd.DataFrame(
    adult_height_laplace_model.sample_posterior(
        _jrng, adult_height_svi.params, sample_shape=(10_000,)
    )
)
samples.head()
[43]:
| mu | sigma | |
|---|---|---|
| 0 | 154.254120 | 8.097201 |
| 1 | 154.200668 | 7.438225 |
| 2 | 153.919067 | 7.663988 |
| 3 | 154.778809 | 7.601790 |
| 4 | 154.752045 | 7.663299 |
Code 4.35¶
[44]:
# Tabular posterior summary; agrees with print_summary.
samples.describe()
[44]:
| mu | sigma | |
|---|---|---|
| count | 10000.000000 | 10000.000000 |
| mean | 154.605774 | 7.747313 |
| std | 0.413330 | 0.290448 |
| min | 152.837646 | 6.659805 |
| 25% | 154.328228 | 7.547918 |
| 50% | 154.603020 | 7.743933 |
| 75% | 154.886684 | 7.937230 |
| max | 156.360703 | 8.950572 |
Code 4.36¶
[45]:
# `auto_loc` is the Laplace mean in the guide's *unconstrained* space: sigma
# has Uniform(0, 50) support and is transformed, so the second entry is not
# sigma itself (hence the surprising sign/scale) — TODO confirm against
# numpyro's AutoLaplaceApproximation docs.
# Workaround used here: draw posterior samples, estimate their mean/cov, then
# resample from the implied multivariate normal.
samples = pd.DataFrame(
    adult_height_laplace_model.sample_posterior(
        _jrng, adult_height_svi.params, sample_shape=(10_000,)
    )
)
vcov = samples.cov()
samples = pd.DataFrame(
    dist.MultivariateNormal(
        loc=samples.mean().values, covariance_matrix=vcov.values
    ).sample(_jrng, sample_shape=(10_000,)),
    columns=["mu", "sigma"],
)
samples
[45]:
| mu | sigma | |
|---|---|---|
| 0 | 154.251968 | 8.095943 |
| 1 | 154.198425 | 7.442693 |
| 2 | 153.916321 | 7.673726 |
| 3 | 154.777603 | 7.605191 |
| 4 | 154.750793 | 7.667237 |
| ... | ... | ... |
| 9995 | 154.689835 | 7.479685 |
| 9996 | 154.250900 | 7.803380 |
| 9997 | 154.720367 | 7.816793 |
| 9998 | 154.711197 | 8.014578 |
| 9999 | 154.432938 | 7.873176 |
10000 rows × 2 columns
Code 4.37¶
[46]:
# Raw height-vs-weight relationship among adults.
df2.plot(kind="scatter", x="weight", y="height", backend="matplotlib")
[46]:
<Axes: xlabel='weight', ylabel='height'>
Code 4.38¶
[47]:
# Prior draws for the intercept (a) and slope (b) of the regression line.
# Fix: `a` was previously sampled with the stale global `jrng`, so the split
# on the line above was a no-op; use the freshly split `_jrng` instead.
_, _jrng = jax.random.split(_jrng)
a = dist.Normal(loc=178, scale=20).sample(_jrng, sample_shape=(100,))
_, _jrng = jax.random.split(_jrng)
b = dist.Normal(loc=0, scale=10).sample(_jrng, sample_shape=(100,))
[48]:
def adult_height_model(
    height, weight, *, average_weight, alpha_prior, beta_prior, sigma_prior
):
    """Linear regression of height on mean-centered weight.

    height: observed heights, or None for prior sampling.
    weight: predictor values; centering by average_weight makes alpha the
        expected height at the average weight.
    alpha_prior/beta_prior: kwargs for Normal; sigma_prior: kwargs for Uniform.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.Normal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    # Recorded as a deterministic site so Predictive returns the mean line too.
    forecast = numpyro.deterministic(
        "forecast", alpha + beta * (weight - average_weight)
    )
    height = numpyro.sample(
        "height", dist.Normal(loc=forecast, scale=sigma), obs=height
    )
    return height
[49]:
# 100 prior draws of (alpha, beta) to visualize the implied regression lines.
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=100)(
    jrng,
    height=None,
    weight=df2["weight"].values,
    average_weight=df2["weight"].mean(),
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 10},
    sigma_prior={"low": 0, "high": 50},
)
Code 4.39¶
[50]:
def plot_prior_lines(a, b):
    """Plot 100 prior regression lines over the observed weight range.

    a, b: arrays of at least 100 intercept and slope draws.
    Reference lines: 0 cm (impossible) and 272 cm (tallest recorded human).
    """
    plt.subplot(
        xlim=(df2.weight.min(), df2.weight.max()),
        ylim=(-100, 400),
        xlabel="weight",
        ylabel="height",
    )
    plt.axhline(y=0, c="k", ls="--")
    plt.axhline(y=272, c="k", ls="-", lw=0.5)
    plt.title("b ~ Normal(0, 10)")
    xbar = df2.weight.mean()
    x = jnp.linspace(df2.weight.min(), df2.weight.max(), 101)
    for i in range(100):
        plt.plot(
            x,
            a[i] + b[i] * (x - xbar),
            "k",
            alpha=0.2,
        )
    plt.show()
plot_prior_lines(a=prior_samples["alpha"], b=prior_samples["beta"])
Code 4.40¶
[51]:
# LogNormal(0, 1) prior for the slope: strictly positive, concentrated near 1.
b = dist.LogNormal(loc=0, scale=1).sample(jrng, sample_shape=(10_000,))
az.plot_kde(b)
[51]:
<Axes: >
Code 4.41¶
[52]:
def adult_height_model(
    height, weight, *, average_weight, alpha_prior, beta_prior, sigma_prior
):
    """Linear model as before, but beta ~ LogNormal constrains the slope > 0.

    NOTE(review): redefines the previous model with only the beta prior
    changed — the distribution family could instead be a parameter.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    forecast = numpyro.deterministic(
        "forecast", alpha + beta * (weight - average_weight)
    )
    height = numpyro.sample(
        "height", dist.Normal(loc=forecast, scale=sigma), obs=height
    )
    return height
[53]:
# Prior lines under the LogNormal slope: all increasing, far more plausible.
prior_samples = numpyro.infer.Predictive(adult_height_model, num_samples=100)(
    jrng,
    height=None,
    weight=df2["weight"].values,
    average_weight=df2["weight"].mean(),
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
)
plot_prior_lines(a=prior_samples["alpha"], b=prior_samples["beta"])
Code 4.42¶
[54]:
# Code 4.42: fit the linear model by quadratic (Laplace) approximation.
Howell1 = pd.read_csv("../data/Howell1.csv", sep=";")
df = Howell1
df2 = df[df["age"] >= 18]
average_weight = df2["weight"].mean()
def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear height ~ centered-weight model; beta ~ LogNormal keeps slope > 0."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    # mu = numpyro.deterministic("mu", alpha + beta * (weight - average_weight))
    mu = alpha + beta * (weight - average_weight)
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height
guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight=df2["weight"].values,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values,
).run(jrng, 5_000)
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
numpyro.diagnostics.print_summary(posterior_samples, 0.89, False)
100%|███████████████████| 5000/5000 [00:01<00:00, 3308.69it/s, init loss: 5353.8887, avg. loss [4751-5000]: 1078.9313]
mean std median 5.5% 94.5% n_eff r_hat
alpha 154.61 0.27 154.62 154.23 155.08 1000.98 1.00
beta 0.91 0.04 0.91 0.84 0.97 1109.19 1.00
sigma 5.09 0.20 5.09 4.81 5.43 773.71 1.00
Code 4.43¶
[55]:
# Code 4.43: NOTE(review) — this cell repeats the previous fit verbatim,
# kept only to mirror the book's code numbering.
def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear height ~ centered-weight model (duplicate of the cell above)."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    # mu = numpyro.deterministic("mu", alpha + beta * (weight - average_weight))
    mu = alpha + beta * (weight - average_weight)
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height
guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight=df2["weight"].values,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values,
).run(jrng, 5_000)
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
numpyro.diagnostics.print_summary(posterior_samples, 0.89, False)
100%|███████████████████| 5000/5000 [00:01<00:00, 3371.98it/s, init loss: 5353.8887, avg. loss [4751-5000]: 1078.9313]
mean std median 5.5% 94.5% n_eff r_hat
alpha 154.60 0.27 154.61 154.20 155.07 916.04 1.00
beta 0.91 0.04 0.91 0.84 0.97 892.87 1.00
sigma 5.08 0.19 5.08 4.79 5.40 979.48 1.00
Code 4.44¶
[56]:
# Tabular posterior summary for the linear model.
numpyro.diagnostics.print_summary(posterior_samples, 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat
alpha 154.60 0.27 154.61 154.20 155.07 916.04 1.00
beta 0.91 0.04 0.91 0.84 0.97 892.87 1.00
sigma 5.08 0.19 5.08 4.79 5.40 979.48 1.00
Code 4.45¶
[57]:
# Posterior covariances are tiny thanks to centering the predictor.
pd.DataFrame(posterior_samples).cov().round(3)
[57]:
| alpha | beta | sigma | |
|---|---|---|---|
| alpha | 0.073 | -0.001 | -0.002 |
| beta | -0.001 | 0.002 | 0.000 |
| sigma | -0.002 | 0.000 | 0.036 |
Code 4.46¶
[58]:
# Observed data with the posterior-mean regression line overlaid.
fig = pd.DataFrame(df2[["weight", "height"]]).plot(
    kind="scatter",
    x="weight",
    y="height",
)
x = jnp.linspace(df2["weight"].min() * 0.95, df2["weight"].max() * 1.05)
y = posterior_samples["alpha"].mean() + posterior_samples["beta"].mean() * (
    x - average_weight
)
fig.add_trace(go.Scatter(x=x, y=y, name="posterior_mean"))
Code 4.47¶
[59]:
# Peek at individual posterior draws.
pd.DataFrame(posterior_samples).head()
[59]:
| alpha | beta | sigma | |
|---|---|---|---|
| 0 | 154.845612 | 0.837958 | 5.169509 |
| 1 | 154.411041 | 0.881906 | 5.131563 |
| 2 | 154.510010 | 0.946982 | 5.246259 |
| 3 | 154.651245 | 0.865017 | 5.103053 |
| 4 | 154.748108 | 0.853965 | 4.906911 |
Code 4.48¶
[60]:
# Refit on only the first 10 rows to show how uncertainty grows with less data.
def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Same linear model; redefined here to mirror the book's code listing."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    # mu = numpyro.deterministic("mu", alpha + beta * (weight - average_weight))
    mu = alpha + beta * (weight - average_weight)
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height
guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight=df2["weight"].values[:10],
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values[:10],
).run(jrng, 5_000)
100%|██████████████████████| 5000/5000 [00:01<00:00, 3435.39it/s, init loss: 194.4541, avg. loss [4751-5000]: 37.0885]
Code 4.49¶
[61]:
# 20 posterior regression lines from the 10-point fit: visibly more spread.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(20,))
numpyro.diagnostics.print_summary(posterior_samples, 0.89, False)
fig = pd.DataFrame(df2[["weight", "height"]].iloc[:10]).plot(
    kind="scatter",
    x="weight",
    y="height",
)
x = jnp.linspace(df2["weight"].min() * 0.95, df2["weight"].max() * 1.05)
for i in range(20):
    y = posterior_samples["alpha"][i] + posterior_samples["beta"][i] * (
        x - average_weight
    )
    fig.add_trace(
        go.Scatter(x=x, y=y, line={"color": "black"}, opacity=0.3, showlegend=False)
    )
fig
mean std median 5.5% 94.5% n_eff r_hat
alpha 152.06 1.38 152.26 150.57 154.32 20.59 0.95
beta 0.96 0.15 0.92 0.76 1.21 21.22 1.13
sigma 4.49 1.38 4.37 3.11 5.50 38.15 0.96
Code 4.50¶
[62]:
# Refit on the full data with mu recorded as a deterministic site, so
# Predictive can return the posterior of the mean line directly.
def adult_height_model(
    weight, *, average_weight, alpha_prior, beta_prior, sigma_prior, height=None
):
    """Linear model with mu exposed as a deterministic site."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = numpyro.deterministic("mu", alpha + beta * (weight - average_weight))
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height
guide = AutoLaplaceApproximation(adult_height_model)
svi = SVI(
    model=adult_height_model,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight=df2["weight"].values,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=df2["height"].values,
).run(jrng, 5_000)
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
# numpyro.diagnostics.print_summary(posterior_samples, 0.89, False)
100%|███████████████████| 5000/5000 [00:01<00:00, 3199.73it/s, init loss: 5353.8887, avg. loss [4751-5000]: 1078.9313]
[63]:
# Posterior of the expected height for an individual weighing 50 kg.
mu_at_50 = posterior_samples["alpha"] + posterior_samples["beta"] * (
    50 - average_weight
)
Code 4.51¶
[64]:
# Density of mu given weight = 50.
az.plot_kde(mu_at_50, label="mu|weight=50")
[64]:
<Axes: >
Code 4.52¶
[65]:
# 89% HPDI of the mean height at weight = 50.
numpyro.diagnostics.hpdi(mu_at_50, prob=0.89)
[65]:
array([158.57683, 159.64386], dtype=float32)
Code 4.53¶
[66]:
# The recorded mu site: one column per training observation, one row per draw.
mu = pd.DataFrame(posterior_samples["mu"])
mu.columns.name = "training sample"
mu.index.name = "posterior predictive sample"
mu
[66]:
| training sample | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| posterior predictive sample | |||||||||||||||||||||
| 0 | 157.304321 | 146.837982 | 142.572968 | 162.118820 | 151.260010 | 171.303024 | 148.460266 | 164.369080 | 145.346542 | 163.453278 | ... | 153.798096 | 157.356644 | 149.533066 | 151.050690 | 150.841354 | 156.571671 | 144.770889 | 161.307693 | 163.060791 | 161.647842 |
| 1 | 156.541153 | 146.848007 | 142.898056 | 161.000000 | 150.943359 | 169.505722 | 148.350449 | 163.084015 | 145.466736 | 162.235870 | ... | 153.293945 | 156.589615 | 149.343994 | 150.749496 | 150.555634 | 155.862625 | 144.933624 | 160.248779 | 161.872375 | 160.563812 |
| 2 | 157.259293 | 147.320450 | 143.270386 | 161.831146 | 151.519608 | 170.552475 | 148.860977 | 163.968002 | 145.904175 | 163.098343 | ... | 153.929779 | 157.308975 | 149.879700 | 151.320831 | 151.122055 | 156.563568 | 145.357544 | 161.060883 | 162.725647 | 161.383896 |
| 3 | 157.621613 | 146.588577 | 142.092621 | 162.696808 | 151.250031 | 172.378296 | 148.298706 | 165.068909 | 145.016373 | 164.103516 | ... | 153.925552 | 157.676773 | 149.429581 | 151.029373 | 150.808716 | 156.849304 | 144.409561 | 161.841751 | 163.689789 | 162.200317 |
| 4 | 157.061279 | 146.611618 | 142.353378 | 161.868118 | 151.026596 | 171.037689 | 148.231308 | 164.114792 | 145.122528 | 163.200455 | ... | 153.560638 | 157.113525 | 149.302399 | 150.817596 | 150.608612 | 156.329803 | 144.547806 | 161.058273 | 162.808594 | 161.397888 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 157.155518 | 146.734146 | 142.487442 | 161.949341 | 151.137177 | 171.094101 | 148.349457 | 164.189941 | 145.249100 | 163.278076 | ... | 153.664352 | 157.207626 | 149.417648 | 150.928757 | 150.720322 | 156.426025 | 144.675934 | 161.141693 | 162.887268 | 161.480377 |
| 996 | 156.507141 | 146.769485 | 142.801392 | 160.986465 | 150.883652 | 169.531265 | 148.278824 | 163.080063 | 145.381866 | 162.228027 | ... | 153.245026 | 156.555832 | 149.276932 | 150.688889 | 150.494141 | 155.825516 | 144.846298 | 160.231812 | 161.862869 | 160.548279 |
| 997 | 157.638733 | 146.461838 | 141.907242 | 162.780121 | 151.184082 | 172.587845 | 148.194260 | 165.183151 | 144.869125 | 164.205170 | ... | 153.894470 | 157.694626 | 149.339890 | 150.960541 | 150.737000 | 156.856354 | 144.254395 | 161.913910 | 163.786041 | 162.277161 |
| 998 | 157.329163 | 147.508148 | 143.506088 | 161.846832 | 151.657532 | 170.464767 | 149.030411 | 163.958344 | 146.108658 | 163.099014 | ... | 154.039124 | 157.378265 | 150.037064 | 151.461105 | 151.264694 | 156.641693 | 145.568497 | 161.085709 | 162.730728 | 161.404892 |
| 999 | 156.900024 | 146.914627 | 142.845581 | 161.493301 | 151.133453 | 170.255478 | 148.462357 | 163.640152 | 145.491699 | 162.766434 | ... | 153.554916 | 156.949951 | 149.485870 | 150.933746 | 150.734039 | 156.201035 | 144.942505 | 160.719437 | 162.391983 | 161.043961 |
1000 rows × 352 columns
Code 4.54¶
[67]:
# Posterior of mu over an evenly spaced weight grid (25..70 kg).
_, _jrng = jax.random.split(_jrng)
weight_seq = jnp.arange(start=25, stop=71, step=1)
mu = numpyro.infer.Predictive(guide.model, posterior_samples, return_sites=["mu"])(
    _jrng,
    weight=weight_seq,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
)["mu"]
mu = pd.DataFrame(mu, columns=pd.Index(weight_seq, name="weight"))
mu.index.name = "posterior predictive sample"
# Sanity check: Predictive's mu matches the hand-computed regression line.
assert (
    posterior_samples["alpha"][0] + posterior_samples["beta"][0] * (25 - average_weight)
    == mu.iat[0, 0]
)
mu
[67]:
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| posterior predictive sample | |||||||||||||||||||||
| 0 | 136.236908 | 137.159882 | 138.082840 | 139.005814 | 139.928787 | 140.851761 | 141.774734 | 142.697708 | 143.620682 | 144.543655 | ... | 169.463928 | 170.386902 | 171.309860 | 172.232834 | 173.155807 | 174.078781 | 175.001755 | 175.924728 | 176.847702 | 177.770676 |
| 1 | 137.030075 | 137.884857 | 138.739655 | 139.594437 | 140.449234 | 141.304016 | 142.158798 | 143.013596 | 143.868378 | 144.723175 | ... | 167.802475 | 168.657272 | 169.512054 | 170.366852 | 171.221634 | 172.076431 | 172.931213 | 173.786011 | 174.640793 | 175.495590 |
| 2 | 137.253662 | 138.130112 | 139.006577 | 139.883026 | 140.759476 | 141.635941 | 142.512390 | 143.388840 | 144.265305 | 145.141754 | ... | 168.806061 | 169.682510 | 170.558960 | 171.435425 | 172.311874 | 173.188339 | 174.064789 | 174.941238 | 175.817703 | 176.694153 |
| 3 | 135.413483 | 136.386444 | 137.359390 | 138.332336 | 139.305283 | 140.278229 | 141.251175 | 142.224121 | 143.197067 | 144.170013 | ... | 170.439606 | 171.412552 | 172.385513 | 173.358459 | 174.331406 | 175.304352 | 176.277298 | 177.250244 | 178.223190 | 179.196136 |
| 4 | 136.027405 | 136.948914 | 137.870407 | 138.791916 | 139.713409 | 140.634918 | 141.556427 | 142.477921 | 143.399429 | 144.320938 | ... | 169.201523 | 170.123016 | 171.044525 | 171.966034 | 172.887527 | 173.809036 | 174.730530 | 175.652039 | 176.573547 | 177.495041 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 136.178604 | 137.097610 | 138.016617 | 138.935623 | 139.854630 | 140.773636 | 141.692642 | 142.611664 | 143.530670 | 144.449677 | ... | 169.262894 | 170.181900 | 171.100906 | 172.019913 | 172.938919 | 173.857925 | 174.776932 | 175.695953 | 176.614960 | 177.533966 |
| 996 | 136.906448 | 137.765167 | 138.623871 | 139.482590 | 140.341309 | 141.200027 | 142.058731 | 142.917450 | 143.776169 | 144.634888 | ... | 167.820206 | 168.678925 | 169.537628 | 170.396347 | 171.255066 | 172.113785 | 172.972504 | 173.831207 | 174.689926 | 175.548645 |
| 997 | 135.141022 | 136.126648 | 137.112289 | 138.097931 | 139.083557 | 140.069199 | 141.054825 | 142.040466 | 143.026108 | 144.011734 | ... | 170.623886 | 171.609528 | 172.595154 | 173.580795 | 174.566422 | 175.552063 | 176.537689 | 177.523331 | 178.508972 | 179.494598 |
| 998 | 137.560684 | 138.426758 | 139.292816 | 140.158890 | 141.024948 | 141.891022 | 142.757080 | 143.623154 | 144.489212 | 145.355286 | ... | 168.739059 | 169.605118 | 170.471191 | 171.337250 | 172.203323 | 173.069382 | 173.935455 | 174.801514 | 175.667587 | 176.533646 |
| 999 | 136.800659 | 137.681229 | 138.561783 | 139.442352 | 140.322906 | 141.203476 | 142.084030 | 142.964600 | 143.845154 | 144.725723 | ... | 168.500885 | 169.381439 | 170.262009 | 171.142563 | 172.023132 | 172.903687 | 173.784256 | 174.664810 | 175.545380 | 176.425934 |
1000 rows × 46 columns
Code 4.55¶
[68]:
df2[["weight", "height"]].plot(kind="scatter", x="weight", y="height", opacity=0)
for i in range(100):
plt.plot(weight_seq, mu.values[i], "o", c="royalblue", alpha=0.1)
Code 4.56¶
[69]:
# Posterior mean of mu at each grid weight (as a one-row frame).
mu_mean = mu.mean().to_frame().T
mu_mean
[69]:
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 136.505142 | 137.410141 | 138.315186 | 139.220139 | 140.125198 | 141.030197 | 141.93512 | 142.84024 | 143.745255 | 144.650284 | ... | 169.085724 | 169.990814 | 170.895859 | 171.800781 | 172.705841 | 173.610886 | 174.515793 | 175.420837 | 176.325867 | 177.23085 |
1 rows × 46 columns
[70]:
mu_hpdi = mu.apply(lambda x: numpyro.diagnostics.hpdi(x, prob=0.89))
mu_hpdi
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.
[70]:
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 135.026993 | 136.026367 | 136.997055 | 137.951721 | 138.926727 | 139.908661 | 140.870804 | 141.861160 | 142.846619 | 143.793472 | ... | 167.961838 | 168.830612 | 169.670883 | 170.516129 | 171.353882 | 172.191635 | 173.028885 | 173.860092 | 174.689926 | 175.517319 |
| 1 | 137.909683 | 138.774658 | 139.614487 | 140.431335 | 141.287354 | 142.156265 | 142.993195 | 143.865662 | 144.713287 | 145.557938 | ... | 170.256363 | 171.241119 | 172.204147 | 173.178024 | 174.145081 | 175.113068 | 176.089157 | 177.051392 | 178.017044 | 178.985031 |
2 rows × 46 columns
Code 4.57¶
[71]:
ax = df2[["weight", "height"]].plot(
kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
plt.plot(weight_seq, mu_mean.T, "k-")
plt.fill_between(weight_seq, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.4)
[71]:
<matplotlib.collections.PolyCollection at 0x7d874df67ce0>
Code 4.58¶
[72]:
# Code 4.58: re-implement the "link" computation by hand.
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
# Drop the deterministic mu site so the frame holds only the free parameters.
posterior_samples.pop("mu")
posterior_samples = pd.DataFrame(posterior_samples)
def mu_link(weight):
    """Posterior of the regression mean evaluated at a single weight value."""
    return posterior_samples["alpha"] + posterior_samples["beta"] * (
        weight - average_weight
    )
mu = pd.concat([mu_link(_weight) for _weight in weight_seq], axis=1)
mu.columns = pd.Index(weight_seq, name="weight")
mu.index.name = "posterior predictive sample"
mu
[72]:
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| posterior predictive sample | |||||||||||||||||||||
| 0 | 135.985901 | 136.901230 | 137.816559 | 138.731888 | 139.647217 | 140.562546 | 141.477875 | 142.393219 | 143.308548 | 144.223877 | ... | 168.937790 | 169.853119 | 170.768448 | 171.683777 | 172.599121 | 173.514450 | 174.429779 | 175.345108 | 176.260437 | 177.175766 |
| 1 | 137.509995 | 138.359192 | 139.208389 | 140.057587 | 140.906784 | 141.755981 | 142.605164 | 143.454361 | 144.303558 | 145.152756 | ... | 168.081024 | 168.930222 | 169.779419 | 170.628616 | 171.477814 | 172.327011 | 173.176208 | 174.025391 | 174.874588 | 175.723785 |
| 2 | 135.088242 | 136.062073 | 137.035889 | 138.009720 | 138.983551 | 139.957382 | 140.931213 | 141.905029 | 142.878860 | 143.852692 | ... | 170.146027 | 171.119858 | 172.093689 | 173.067520 | 174.041336 | 175.015167 | 175.988998 | 176.962830 | 177.936646 | 178.910477 |
| 3 | 134.468567 | 135.461731 | 136.454910 | 137.448090 | 138.441254 | 139.434433 | 140.427612 | 141.420792 | 142.413956 | 143.407135 | ... | 170.222870 | 171.216049 | 172.209229 | 173.202393 | 174.195572 | 175.188751 | 176.181915 | 177.175095 | 178.168274 | 179.161453 |
| 4 | 136.359711 | 137.267639 | 138.175568 | 139.083481 | 139.991409 | 140.899338 | 141.807251 | 142.715179 | 143.623108 | 144.531036 | ... | 169.044968 | 169.952896 | 170.860809 | 171.768738 | 172.676666 | 173.584595 | 174.492508 | 175.400436 | 176.308365 | 177.216278 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 136.797119 | 137.712357 | 138.627609 | 139.542847 | 140.458099 | 141.373337 | 142.288589 | 143.203827 | 144.119080 | 145.034332 | ... | 169.745956 | 170.661209 | 171.576462 | 172.491699 | 173.406952 | 174.322189 | 175.237442 | 176.152679 | 177.067932 | 177.983170 |
| 996 | 135.942429 | 136.874466 | 137.806503 | 138.738525 | 139.670563 | 140.602600 | 141.534637 | 142.466660 | 143.398697 | 144.330734 | ... | 169.495636 | 170.427673 | 171.359711 | 172.291733 | 173.223770 | 174.155807 | 175.087845 | 176.019867 | 176.951904 | 177.883942 |
| 997 | 136.452545 | 137.376831 | 138.301117 | 139.225418 | 140.149719 | 141.074005 | 141.998291 | 142.922592 | 143.846878 | 144.771179 | ... | 169.727097 | 170.651398 | 171.575684 | 172.499985 | 173.424271 | 174.348572 | 175.272858 | 176.197144 | 177.121445 | 178.045746 |
| 998 | 136.993576 | 137.869659 | 138.745728 | 139.621811 | 140.497894 | 141.373978 | 142.250061 | 143.126144 | 144.002228 | 144.878296 | ... | 168.532501 | 169.408585 | 170.284668 | 171.160751 | 172.036835 | 172.912918 | 173.789001 | 174.665070 | 175.541153 | 176.417236 |
| 999 | 135.966492 | 136.885849 | 137.805206 | 138.724579 | 139.643936 | 140.563309 | 141.482666 | 142.402039 | 143.321396 | 144.240753 | ... | 169.063583 | 169.982941 | 170.902313 | 171.821671 | 172.741043 | 173.660400 | 174.579773 | 175.499130 | 176.418488 | 177.337860 |
1000 rows × 46 columns
Code 4.59¶
[73]:
# Code 4.59: simulate heights (mu + observation noise) over a grid of weights
# using Predictive instead of hand-rolled sampling.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
weight_seq = jnp.arange(start=25, stop=71, step=1)
predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["height"]
)
# NOTE(review): the same `_jrng` is used for sample_posterior and Predictive —
# confirm this key reuse is intended.
height = predictive(
    _jrng,
    weight=weight_seq,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
)["height"]
height = pd.DataFrame(height, columns=pd.Index(weight_seq, name="weight"))
height.index.name = "posterior predictive height sample"
height
[73]:
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| posterior predictive height sample | |||||||||||||||||||||
| 0 | 136.508286 | 149.435226 | 135.722305 | 149.388138 | 139.763763 | 139.949677 | 140.646118 | 138.855057 | 142.752548 | 143.562958 | ... | 169.097351 | 173.265259 | 164.546005 | 167.680725 | 166.906647 | 164.859055 | 172.325638 | 177.131287 | 174.472702 | 171.804504 |
| 1 | 137.130280 | 131.644226 | 131.078796 | 146.600143 | 134.685349 | 142.405029 | 145.415344 | 141.809189 | 143.663727 | 150.538666 | ... | 171.158493 | 173.759567 | 166.737335 | 173.044586 | 169.006485 | 182.760620 | 178.568680 | 176.966095 | 173.372452 | 167.617493 |
| 2 | 140.201584 | 140.972473 | 138.505524 | 138.571991 | 145.531540 | 137.499023 | 151.230240 | 136.424301 | 141.453522 | 144.187225 | ... | 165.414032 | 174.831863 | 168.857559 | 164.795746 | 169.127579 | 177.319336 | 182.095734 | 170.876938 | 182.959732 | 178.725433 |
| 3 | 137.223373 | 138.317673 | 135.861816 | 141.323700 | 144.002625 | 141.397507 | 146.272064 | 141.174606 | 144.210037 | 150.681000 | ... | 165.021179 | 174.404190 | 171.814453 | 173.515961 | 168.846237 | 175.028915 | 176.868759 | 169.214584 | 163.540894 | 176.363586 |
| 4 | 134.241409 | 144.002289 | 140.749893 | 137.863724 | 141.580338 | 139.624512 | 135.989700 | 139.275009 | 140.649887 | 138.531845 | ... | 167.624329 | 169.881714 | 172.509033 | 160.924622 | 168.618073 | 169.110428 | 183.953873 | 183.501358 | 181.296814 | 176.981949 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | 141.023254 | 140.618423 | 148.049911 | 143.028687 | 135.105743 | 136.967056 | 142.688721 | 145.131409 | 145.675903 | 143.714188 | ... | 161.923935 | 166.959091 | 167.248596 | 160.272964 | 170.994476 | 170.879868 | 175.324860 | 184.061798 | 176.259384 | 177.054688 |
| 996 | 136.336578 | 138.168930 | 135.906570 | 144.396240 | 139.384766 | 147.582108 | 133.617355 | 143.931732 | 146.498352 | 142.418411 | ... | 159.824448 | 166.302322 | 163.272964 | 166.170166 | 166.100311 | 174.871994 | 173.713211 | 170.594254 | 167.849503 | 179.801788 |
| 997 | 128.789352 | 137.093933 | 141.240921 | 131.618515 | 128.948471 | 133.299667 | 146.318787 | 131.432816 | 147.747055 | 143.982483 | ... | 172.134598 | 169.760605 | 170.733414 | 175.537643 | 174.765945 | 180.729172 | 179.618698 | 169.870499 | 170.286469 | 179.472046 |
| 998 | 126.851830 | 134.977737 | 135.001389 | 131.294464 | 147.969147 | 137.543365 | 146.132797 | 140.221573 | 146.279465 | 139.622467 | ... | 166.252594 | 165.666428 | 172.711197 | 175.712952 | 176.767441 | 178.774521 | 177.582687 | 180.221420 | 176.854004 | 167.835709 |
| 999 | 135.678772 | 135.806549 | 145.103104 | 138.308670 | 141.327972 | 148.940582 | 133.575974 | 145.009354 | 141.294922 | 147.490875 | ... | 167.654053 | 172.131699 | 167.774033 | 174.673126 | 173.745056 | 175.865540 | 175.854340 | 176.483749 | 174.906326 | 168.637466 |
1000 rows × 46 columns
Code 4.60¶
[74]:
# Code 4.60: 89% HPDI of the simulated heights, per weight column.
def _height_hpdi_89(column):
    return numpyro.diagnostics.hpdi(column, prob=0.89)


height_hpdi = height.apply(_height_hpdi_89)
height_hpdi
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.
[74]:
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 129.079712 | 128.893143 | 129.459946 | 130.970932 | 131.758331 | 132.952377 | 133.561066 | 134.395966 | 136.13765 | 137.199585 | ... | 160.902451 | 161.858826 | 163.364090 | 162.511414 | 165.133896 | 166.108643 | 166.287201 | 167.202789 | 168.232437 | 168.767731 |
| 1 | 144.476105 | 145.375854 | 145.479507 | 147.659012 | 148.286728 | 149.444717 | 149.126480 | 150.059341 | 151.34671 | 152.877716 | ... | 177.465652 | 178.126465 | 179.430649 | 178.957062 | 181.207123 | 182.581741 | 182.563736 | 182.397293 | 184.117691 | 185.513077 |
2 rows × 46 columns
Code 4.61¶
[75]:
# Code 4.61: mean line, mu band (dark), and simulated-height band (light).
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
ax.plot(weight_seq, mu_mean.T, "k-")
ax.fill_between(weight_seq, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.4)
ax.fill_between(
    weight_seq, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
[75]:
<matplotlib.collections.PolyCollection at 0x7d874de182f0>
Code 4.62¶
[76]:
# Code 4.62: redo the posterior predictive with 10x more draws for a smoother band.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
weight_seq = jnp.arange(start=25, stop=71, step=1)
predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["height"]
)
height = predictive(
    _jrng,
    weight=weight_seq,
    average_weight=average_weight,
    alpha_prior={"loc": 178, "scale": 20},
    beta_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
)["height"]
height = pd.DataFrame(height, columns=pd.Index(weight_seq, name="weight"))
height.index.name = "posterior predictive height sample"
height_hpdi = height.apply(lambda col: numpyro.diagnostics.hpdi(col, prob=0.89))
display(height_hpdi)
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
ax.plot(weight_seq, mu_mean.T, "k-")
ax.fill_between(weight_seq, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.4)
ax.fill_between(
    weight_seq, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 128.198639 | 129.154419 | 130.445114 | 130.696640 | 131.987961 | 133.248535 | 133.680359 | 134.982620 | 135.565125 | 136.709518 | ... | 160.880524 | 161.950745 | 162.872162 | 163.943100 | 164.586792 | 165.942871 | 166.361206 | 167.636047 | 168.283356 | 169.236526 |
| 1 | 144.747787 | 146.053528 | 146.836273 | 147.281357 | 148.310165 | 149.428192 | 150.228561 | 151.407837 | 151.797745 | 152.947876 | ... | 177.590027 | 178.494415 | 179.222809 | 180.173004 | 180.973145 | 182.395248 | 182.904404 | 183.754227 | 184.861252 | 185.728287 |
2 rows × 46 columns
[76]:
<matplotlib.collections.PolyCollection at 0x7d874dd68d40>
Code 4.63¶
[77]:
# Code 4.63: simulate heights by hand instead of using Predictive.
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
posterior_samples.pop("mu")  # recomputed inside sim_height
posterior_samples = pd.DataFrame(posterior_samples)


def sim_height(weight, jrng):
    """Draw one simulated height per posterior sample for a single weight value."""
    mu = posterior_samples["alpha"] + posterior_samples["beta"] * (
        weight - average_weight
    )
    return dist.Normal(loc=mu, scale=posterior_samples["sigma"]).sample(
        jrng,
    )


height = []
for _weight in weight_seq:
    # BUG FIX: the original split a fresh key into `jrng` but then passed the
    # stale `_jrng` into sim_height, so every weight column reused the same
    # PRNG key (identical underlying standard-normal draws across columns).
    # Pass the freshly split key instead.
    _, jrng = jax.random.split(jrng)
    height.append(sim_height(_weight, jrng))
height = pd.concat(height, axis=1)
height.columns = pd.Index(weight_seq, name="weight")
height.index.name = "posterior predictive sample"
height_hpdi = height.apply(lambda x: numpyro.diagnostics.hpdi(x, prob=0.89))
display(height_hpdi)
ax = df2[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib", alpha=0.5
)
plt.plot(weight_seq, mu_mean.T, "k-")
plt.fill_between(weight_seq, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.4)
plt.fill_between(
    weight_seq, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'Series.swapaxes' is deprecated and will be removed in a future version. Please use 'Series.transpose' instead.
| weight | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | ... | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 128.353043 | 129.231384 | 130.064392 | 130.897827 | 131.913391 | 132.745377 | 133.669144 | 134.482864 | 135.470245 | 136.378235 | ... | 161.073730 | 161.500427 | 162.376999 | 163.222519 | 164.476730 | 164.565491 | 165.759048 | 166.797089 | 167.679825 | 168.559372 |
| 1 | 143.963165 | 144.840317 | 145.646332 | 146.425797 | 147.388672 | 148.204834 | 149.118210 | 149.980774 | 150.996918 | 151.936279 | ... | 177.229584 | 177.667191 | 178.607346 | 179.467682 | 180.740768 | 180.852631 | 182.037155 | 183.107285 | 183.987915 | 184.907028 |
2 rows × 46 columns
[77]:
<matplotlib.collections.PolyCollection at 0x7d870d913a70>
Code 4.64¶
[78]:
df = pd.read_csv("../data/Howell1.csv", sep=";")
df
[78]:
| height | weight | age | male | |
|---|---|---|---|---|
| 0 | 151.765 | 47.825606 | 63.0 | 1 |
| 1 | 139.700 | 36.485807 | 63.0 | 0 |
| 2 | 136.525 | 31.864838 | 65.0 | 0 |
| 3 | 156.845 | 53.041914 | 41.0 | 1 |
| 4 | 145.415 | 41.276872 | 51.0 | 0 |
| ... | ... | ... | ... | ... |
| 539 | 145.415 | 31.127751 | 17.0 | 1 |
| 540 | 162.560 | 52.163080 | 31.0 | 1 |
| 541 | 156.210 | 54.062497 | 21.0 | 0 |
| 542 | 71.120 | 8.051258 | 0.0 | 1 |
| 543 | 158.750 | 52.531624 | 68.0 | 1 |
544 rows × 4 columns
[79]:
# Full Howell data (children included): height vs weight is visibly curved.
df[["weight", "height"]].plot(
    kind="scatter",
    x="weight",
    y="height",
    backend="matplotlib",
)
[79]:
<Axes: xlabel='weight', ylabel='height'>
Code 4.65¶
[80]:
# Code 4.65: quadratic polynomial model of height on standardized weight.
df["weight_s"] = (df["weight"] - df["weight"].mean()) / df["weight"].std()
df["weight_s2"] = df["weight_s"] ** 2


def m4_5(
    weight_s,
    weight_s2,
    *,
    alpha_prior={"loc": 178, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 1},
    sigma_prior={"low": 0, "high": 50},
    height=None,
):
    """Quadratic regression: mu = alpha + beta1 * w + beta2 * w^2.

    beta1 is LogNormal (kept positive); priors arrive as keyword dicts so
    Predictive calls can override them.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.LogNormal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = numpyro.deterministic("mu", alpha + beta1 * weight_s + beta2 * weight_s2)
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)


guide = AutoLaplaceApproximation(m4_5)
_, _jrng = jax.random.split(_jrng)
# NOTE(review): .run receives the global `jrng`, not the freshly split `_jrng`;
# this matches the rest of the notebook, but confirm the key reuse is intended.
svi = SVI(
    model=m4_5,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight_s=df["weight_s"].values,
    weight_s2=df["weight_s2"].values,
    height=df["height"].values,
).run(jrng, 5_000)
100%|███████████████████| 5000/5000 [00:01<00:00, 3524.35it/s, init loss: 8877.6973, avg. loss [4751-5000]: 1770.2781]
Code 4.66¶
[81]:
# Code 4.66: summarize the non-deterministic sites of the quadratic model.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(1000,))
# FIX: the original filtered with `k not in "mu"`, which is a substring test
# (it would also drop any key named "m" or "u"); use direct inequality,
# consistent with the later cells that write `k != "mu"`.
_posterior_samples = {
    k: posterior_samples[k] for k in posterior_samples if k != "mu"
}
numpyro.diagnostics.print_summary(_posterior_samples, 0.89, False)
mean std median 5.5% 94.5% n_eff r_hat
alpha 146.05 0.38 146.06 145.49 146.68 932.87 1.00
beta1 21.73 0.30 21.72 21.26 22.19 903.86 1.00
beta2 -7.79 0.28 -7.79 -8.21 -7.30 940.81 1.00
sigma 5.78 0.19 5.78 5.44 6.02 1084.68 1.00
Code 4.67¶
[82]:
# Code 4.67: posterior predictive of the quadratic model over a standardized grid.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
weight_seq = jnp.linspace(start=-2.2, stop=2, num=50)
weight_seq2 = weight_seq**2
predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive = predictive(
    _jrng,
    weight_s=weight_seq,
    weight_s2=weight_seq2,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(posterior_predictive["mu"], columns=weight_seq)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive["height"], columns=weight_seq
)
[83]:
# Means and 89% HPDI bands of the posterior predictive for mu and height.
def _hpdi_frame(samples):
    """Two-row DataFrame (low, high) of the 89% HPDI for each grid column."""
    return pd.DataFrame(
        numpyro.diagnostics.hpdi(samples, prob=0.89), columns=weight_seq
    )


mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
display(mu_mean)
mu_hpdi = _hpdi_frame(mu_posterior_predictive)
display(mu_hpdi)
height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
display(height_mean)
height_hpdi = _hpdi_frame(height_posterior_predictive)
display(height_hpdi)
| -2.200000 | -2.114286 | -2.028572 | -1.942857 | -1.857143 | -1.771429 | -1.685714 | -1.600000 | -1.514286 | -1.428572 | ... | 1.228571 | 1.314286 | 1.400000 | 1.485714 | 1.571429 | 1.657143 | 1.742857 | 1.828571 | 1.914286 | 2.000000 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60.481586 | 65.229164 | 69.861977 | 74.380379 | 78.783966 | 83.07338 | 87.247482 | 91.307541 | 95.253059 | 99.083488 | ... | 160.993332 | 161.156906 | 161.205719 | 161.139893 | 160.959564 | 160.664368 | 160.255157 | 159.731079 | 159.092514 | 158.339249 |
1 rows × 50 columns
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
| -2.200000 | -2.114286 | -2.028572 | -1.942857 | -1.857143 | -1.771429 | -1.685714 | -1.600000 | -1.514286 | -1.428572 | ... | 1.228571 | 1.314286 | 1.400000 | 1.485714 | 1.571429 | 1.657143 | 1.742857 | 1.828571 | 1.914286 | 2.000000 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 59.013569 | 63.846535 | 68.564537 | 73.261055 | 77.718513 | 82.131294 | 86.331566 | 90.494766 | 94.491516 | 98.429581 | ... | 160.208939 | 160.259521 | 160.205215 | 160.036804 | 159.673935 | 159.251678 | 158.742950 | 158.031952 | 157.244263 | 156.378006 |
| 1 | 61.993671 | 66.578056 | 71.061684 | 75.535278 | 79.797653 | 84.034424 | 88.067047 | 92.095146 | 95.975159 | 99.822418 | ... | 161.825531 | 162.069397 | 162.229767 | 162.295349 | 162.181366 | 162.018936 | 161.797211 | 161.378479 | 160.898743 | 160.359406 |
2 rows × 50 columns
| -2.200000 | -2.114286 | -2.028572 | -1.942857 | -1.857143 | -1.771429 | -1.685714 | -1.600000 | -1.514286 | -1.428572 | ... | 1.228571 | 1.314286 | 1.400000 | 1.485714 | 1.571429 | 1.657143 | 1.742857 | 1.828571 | 1.914286 | 2.000000 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60.416737 | 65.212372 | 69.902946 | 74.325066 | 78.807297 | 83.031853 | 87.275444 | 91.365974 | 95.178131 | 99.095703 | ... | 161.028168 | 161.211151 | 161.167297 | 161.197403 | 160.950775 | 160.647278 | 160.438293 | 159.657928 | 158.956055 | 158.412537 |
1 rows × 50 columns
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
| -2.200000 | -2.114286 | -2.028572 | -1.942857 | -1.857143 | -1.771429 | -1.685714 | -1.600000 | -1.514286 | -1.428572 | ... | 1.228571 | 1.314286 | 1.400000 | 1.485714 | 1.571429 | 1.657143 | 1.742857 | 1.828571 | 1.914286 | 2.000000 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 51.016739 | 55.879940 | 60.490021 | 65.458740 | 69.163567 | 73.893677 | 77.600990 | 82.287376 | 85.649536 | 90.373688 | ... | 151.999893 | 151.800873 | 152.276199 | 151.935120 | 151.718018 | 151.222717 | 150.975449 | 150.216843 | 149.342133 | 148.985733 |
| 1 | 69.423355 | 74.413803 | 79.079216 | 83.902122 | 87.854408 | 92.467232 | 96.193443 | 100.698784 | 104.460617 | 108.844330 | ... | 170.468155 | 170.176056 | 170.734787 | 170.366165 | 170.243973 | 169.753738 | 169.766205 | 168.980286 | 168.354767 | 167.897156 |
2 rows × 50 columns
Code 4.68¶
[84]:
# Code 4.68: quadratic fit with mu band (dark) and height band (light).
ax = df[["weight_s", "height"]].plot(
    kind="scatter", x="weight_s", y="height", backend="matplotlib"
)
ax.plot(weight_seq, mu_mean.loc[0, :], "k")
ax.fill_between(weight_seq, mu_hpdi.loc[0, :], mu_hpdi.loc[1, :], color="k", alpha=0.5)
ax.fill_between(
    weight_seq, height_hpdi.loc[0, :], height_hpdi.loc[1, :], color="k", alpha=0.2
)
plt.show()
Code 4.69¶
[85]:
# Code 4.69: cubic polynomial model of height on standardized weight.
df["weight_s3"] = df["weight_s"] ** 3


def m4_6(
    weight_s,
    weight_s2,
    weight_s3,
    *,
    alpha_prior={"loc": 178, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 10},
    beta3_prior={"loc": 0, "scale": 10},
    sigma_prior={"low": 0, "high": 50},
    height=None,
):
    """Cubic regression: mu = alpha + beta1*w + beta2*w^2 + beta3*w^3."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.LogNormal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    beta3 = numpyro.sample("beta3", dist.Normal(**beta3_prior))
    sigma = numpyro.sample("sigma", dist.Uniform(**sigma_prior))
    mu = numpyro.deterministic(
        "mu", alpha + beta1 * weight_s + beta2 * weight_s2 + beta3 * weight_s3
    )
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)


guide = AutoLaplaceApproximation(m4_6)
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_6,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.1),
    loss=Trace_ELBO(),
    weight_s=df["weight_s"].values,
    weight_s2=df["weight_s2"].values,
    weight_s3=df["weight_s3"].values,
    height=df["height"].values,
).run(jrng, 5_000)
100%|███████████████████| 5000/5000 [00:01<00:00, 3025.12it/s, init loss: 5829.7710, avg. loss [4751-5000]: 1646.5316]
[86]:
# Posterior predictive for the cubic model over the same standardized grid.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
weight_seq = jnp.linspace(start=-2.2, stop=2, num=50)
weight_seq2 = weight_seq**2
weight_seq3 = weight_seq**3
predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive = predictive(
    _jrng,
    weight_s=weight_seq,
    weight_s2=weight_seq2,
    weight_s3=weight_seq3,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(posterior_predictive["mu"], columns=weight_seq)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive["height"], columns=weight_seq
)


def _hpdi_df(samples):
    """89% HPDI per grid column as a two-row (low, high) DataFrame."""
    return pd.DataFrame(
        numpyro.diagnostics.hpdi(samples, prob=0.89), columns=weight_seq
    )


mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
mu_hpdi = _hpdi_df(mu_posterior_predictive)
height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
height_hpdi = _hpdi_df(height_posterior_predictive)
ax = df[["weight_s", "height"]].plot(
    kind="scatter", x="weight_s", y="height", backend="matplotlib"
)
ax.plot(weight_seq, mu_mean.loc[0, :], "k")
ax.fill_between(weight_seq, mu_hpdi.loc[0, :], mu_hpdi.loc[1, :], color="k", alpha=0.5)
ax.fill_between(
    weight_seq, height_hpdi.loc[0, :], height_hpdi.loc[1, :], color="k", alpha=0.2
)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
Code 4.70¶
[87]:
# Code 4.70: same scatter with x tick labels suppressed (weight_s is unitless).
df.plot(kind="scatter", x="weight_s", y="height", backend="matplotlib", xticks=[])
[87]:
<Axes: xlabel='weight_s', ylabel='height'>
Code 4.71¶
[88]:
# Code 4.71: same relationship on the natural weight scale.
df.plot(kind="scatter", x="weight", y="height", backend="matplotlib")
[88]:
<Axes: xlabel='weight', ylabel='height'>
Code 4.72¶
[89]:
# Code 4.72: load the cherry blossoms data (semicolon-separated) and summarize.
df = pd.read_csv("../data/cherry_blossoms.csv", sep=";")
df.describe()
[89]:
| year | doy | temp | temp_upper | temp_lower | |
|---|---|---|---|---|---|
| count | 1215.000000 | 827.000000 | 1124.000000 | 1124.000000 | 1124.000000 |
| mean | 1408.000000 | 104.540508 | 6.141886 | 7.185151 | 5.098941 |
| std | 350.884596 | 6.407036 | 0.663648 | 0.992921 | 0.850350 |
| min | 801.000000 | 86.000000 | 4.670000 | 5.450000 | 0.750000 |
| 25% | 1104.500000 | 100.000000 | 5.700000 | 6.480000 | 4.610000 |
| 50% | 1408.000000 | 105.000000 | 6.100000 | 7.040000 | 5.145000 |
| 75% | 1711.500000 | 109.000000 | 6.530000 | 7.720000 | 5.542500 |
| max | 2015.000000 | 124.000000 | 8.300000 | 12.100000 | 7.740000 |
[90]:
# Temperature over time; temp has missing values (count 1124 of 1215 rows).
df["temp"].plot(backend="matplotlib")
[90]:
<Axes: >
Code 4.73¶
[91]:
# Code 4.73: knot locations at evenly spaced quantiles of year.
df2 = df.dropna(subset=["temp"])
num_knots = 15
# FIX: compute the knot quantiles from df2 — the complete cases that are
# actually used to build the basis (Code 4.74) and fit the model — not from
# the full df, which includes years whose temp is NaN.
knots_list = jnp.quantile(
    df2["year"].values, jnp.linspace(start=0, stop=1, num=num_knots)
)
Code 4.74¶
[92]:
# Code 4.74: cubic B-spline basis over year.
degree = 3
# Repeat each boundary knot `degree` extra times so the basis covers the ends.
knots = jnp.pad(knots_list, (degree, degree), mode="edge")
# Identity-matrix trick: evaluating BSpline with identity coefficients yields
# each of the num_knots+2 basis functions at every observed year, i.e. a
# design matrix with one column per basis function.
B = BSpline(knots, jnp.identity(num_knots + 2), k=degree)(df2.year.values)
Code 4.75¶
[93]:
# Code 4.75: draw every basis function against year.
plt.subplot(
    xlim=(df2.year.min(), df2.year.max()),
    ylim=(0, 1),
    xlabel="year",
    ylabel="basis value",
)
for column in range(B.shape[1]):
    plt.plot(df2.year, B[:, column], "k", alpha=0.5)
Code 4.76¶
[94]:
def m4_7(
    B,
    *,
    alpha_prior={"loc": 6, "scale": 10},
    weight_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1},
    temp=None,
):
    """B-spline regression of temperature: mu = alpha + B @ weight.

    One weight per basis column (sample_shape follows B.shape[1:]); priors are
    keyword dicts so they can be overridden by Predictive.
    """
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    weight = numpyro.sample(
        "weight", dist.Normal(**weight_prior), sample_shape=B.shape[1:]
    )
    mu = numpyro.deterministic("mu", alpha + jnp.dot(B, weight))
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    return numpyro.sample("temp", dist.Normal(loc=mu, scale=sigma), obs=temp)
[95]:
# FIX: the init value was keyed "w", but the sample site in m4_7 is named
# "weight", so init_to_value never matched any site and was a silent no-op.
# Key the init value by the actual site name.
guide = AutoLaplaceApproximation(
    m4_7,
    init_loc_fn=numpyro.infer.init_to_value(
        values={"weight": jnp.zeros(B.shape[1])}
    ),
)
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_7,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    B=B,
    temp=df2["temp"].values,
).run(jrng, 10_000)
100%|█████████████████| 10000/10000 [00:02<00:00, 3825.19it/s, init loss: 5767.3096, avg. loss [9501-10000]: 483.9176]
[96]:
# Summarize every latent site except the per-row deterministic "mu".
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
_posterior_samples = {
    name: draws for name, draws in posterior_samples.items() if name != "mu"
}
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat
alpha 6.52 0.25 6.52 6.11 6.91 9269.95 1.00
sigma 0.36 0.01 0.36 0.35 0.37 9550.55 1.00
weight[0] 2.38 0.86 2.37 1.04 3.74 9323.47 1.00
weight[1] -1.61 0.37 -1.61 -2.19 -1.03 9496.14 1.00
weight[2] 1.72 0.28 1.72 1.26 2.16 9584.92 1.00
weight[3] -0.86 0.28 -0.86 -1.32 -0.44 9534.39 1.00
weight[4] 0.15 0.27 0.15 -0.28 0.58 9329.71 1.00
weight[5] -1.84 0.27 -1.84 -2.26 -1.40 9282.52 1.00
weight[6] 0.95 0.27 0.95 0.53 1.37 9499.81 1.00
weight[7] -2.04 0.27 -2.04 -2.47 -1.61 9144.56 1.00
weight[8] 1.86 0.27 1.86 1.42 2.26 9494.26 1.00
weight[9] -2.13 0.27 -2.12 -2.56 -1.70 9301.43 1.00
weight[10] 0.44 0.27 0.45 0.03 0.88 9275.83 1.00
weight[11] -1.65 0.27 -1.65 -2.10 -1.24 9294.14 1.00
weight[12] -0.10 0.27 -0.10 -0.52 0.33 9388.49 1.00
weight[13] -1.48 0.27 -1.48 -1.93 -1.05 9098.48 1.00
weight[14] 0.39 0.28 0.39 -0.05 0.85 9669.49 1.00
weight[15] 1.91 0.34 1.91 1.39 2.47 9564.52 1.00
weight[16] 1.86 0.76 1.86 0.65 3.07 9867.93 1.00
Code 4.77¶
[97]:
# Code 4.77: each basis function scaled by its posterior-mean weight.
weight_mean = posterior_samples["weight"].mean(axis=0)
plt.subplot(
    xlim=(df2.year.min(), df2.year.max()),
    ylim=(-2, 2),
    xlabel="year",
    ylabel="basis * weight",
)
for idx in range(B.shape[1]):
    plt.plot(df2["year"], weight_mean[idx] * B[:, idx], "k", alpha=0.2)
Code 4.78¶
[98]:
# Code 4.78: 89% HPDI band of the spline mean over the raw data.
mu_hpdi = pd.DataFrame(numpyro.diagnostics.hpdi(posterior_samples["mu"], prob=0.89))
ax = df2.plot(kind="scatter", x="year", y="temp", backend="matplotlib", figsize=(10, 6))
ax.fill_between(
    df2["year"], mu_hpdi.loc[0, :], mu_hpdi.loc[1, :], color="k", alpha=0.5
)
[98]:
<matplotlib.collections.PolyCollection at 0x7d87057d41a0>
Code 4.79¶
[99]:
# Code 4.79: same spline model written with an elementwise multiply-and-sum
# instead of a matrix product.
def m4_7_alt(
    B,
    *,
    alpha_prior={"loc": 6, "scale": 10},
    weight_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1},
    temp=None,
):
    """B-spline regression: mu = alpha + sum(B * weight, axis=-1)."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    weight = numpyro.sample(
        "weight", dist.Normal(**weight_prior), sample_shape=B.shape[1:]
    )
    mu = numpyro.deterministic("mu", alpha + jnp.sum(B * weight, axis=-1))
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    temp = numpyro.sample("temp", dist.Normal(loc=mu, scale=sigma), obs=temp)
    return temp


# FIX: the init value was keyed "w" but the sample site is named "weight",
# so init_to_value silently had no effect; key it by the site name.
guide = AutoLaplaceApproximation(
    m4_7_alt,
    init_loc_fn=numpyro.infer.init_to_value(
        values={"weight": jnp.zeros(B.shape[1])}
    ),
)
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=m4_7_alt,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    B=B,
    temp=df2["temp"].values,
).run(jrng, 10_000)
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
_posterior_samples = {k: v for k, v in posterior_samples.items() if k != "mu"}
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)
100%|█████████████████| 10000/10000 [00:03<00:00, 3220.22it/s, init loss: 5767.3096, avg. loss [9501-10000]: 483.9176]
mean std median 5.5% 94.5% n_eff r_hat
alpha 6.52 0.25 6.52 6.13 6.92 9591.55 1.00
sigma 0.36 0.01 0.36 0.35 0.37 10219.07 1.00
weight[0] 2.39 0.86 2.39 1.12 3.85 9671.32 1.00
weight[1] -1.61 0.36 -1.61 -2.17 -1.05 9958.04 1.00
weight[2] 1.72 0.28 1.72 1.26 2.17 9412.93 1.00
weight[3] -0.86 0.27 -0.86 -1.28 -0.42 9821.61 1.00
weight[4] 0.15 0.27 0.15 -0.26 0.59 9517.43 1.00
weight[5] -1.84 0.27 -1.84 -2.28 -1.43 9679.75 1.00
weight[6] 0.96 0.26 0.95 0.51 1.36 9832.37 1.00
weight[7] -2.04 0.27 -2.04 -2.45 -1.60 9217.47 1.00
weight[8] 1.86 0.26 1.86 1.44 2.28 9860.78 1.00
weight[9] -2.13 0.27 -2.13 -2.54 -1.70 9499.83 1.00
weight[10] 0.45 0.26 0.45 0.03 0.86 9563.80 1.00
weight[11] -1.65 0.27 -1.65 -2.06 -1.21 9835.99 1.00
weight[12] -0.10 0.27 -0.10 -0.52 0.32 9681.86 1.00
weight[13] -1.48 0.27 -1.48 -1.91 -1.04 9888.25 1.00
weight[14] 0.39 0.28 0.39 -0.06 0.84 9416.45 1.00
weight[15] 1.91 0.34 1.91 1.38 2.46 9785.94 1.00
weight[16] 1.86 0.76 1.86 0.64 3.05 9780.90 1.00
Easy¶
4E1¶
The first line is the likelihood.
4E2¶
2 parameters, mu and sigma
4E3¶
\[P(\mu, \sigma | y) = \frac{P(y | \mu, \sigma) \cdot P(\mu) \cdot P(\sigma)}{\int \int P(y | \mu, \sigma) \cdot P(\mu) \cdot P(\sigma) d \mu d \sigma}\]
4E4¶
The second line is the linear model
4E5¶
Three parameters, \(\alpha\), \(\beta\) and \(\sigma\)
Medium¶
4M1¶
[100]:
# 4M1: sample from the prior predictive of the simple Gaussian model.
def m1(y):
    """y ~ Normal(mu, sigma) with mu ~ Normal(0, 10), sigma ~ Exponential(1)."""
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


prior_predictive = numpyro.infer.Predictive(m1, num_samples=10_000)
prior_predictive_samples = pd.DataFrame(prior_predictive(jrng, y=None))
prior_predictive_samples["y"].plot(kind="kde", backend="matplotlib")
[100]:
<Axes: ylabel='Density'>
4M2¶
[101]:
# 4M2: the model translated into a numpyro model function (same definition as
# in 4M1; repeated here so the exercise answer stands on its own).
def m1(y):
    mu = numpyro.sample("mu", dist.Normal(0, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
4M3¶
\[\begin{split}\begin{split}
y & \sim Normal(\mu, \sigma) \\
\mu & = \alpha + \beta \cdot x \\
\alpha & \sim Normal(0, 10) \\
\beta & \sim Uniform(0, 1) \\
\sigma & \sim Exponential(1)
\end{split}\end{split}\]
4M4¶
\[\begin{split}\begin{split}
height & \sim Normal(\mu, \sigma) \\
\mu & = \alpha + \beta \cdot (year - \text{start year}) \\
\alpha & \sim Normal(170, 29) \\
\beta & \sim Normal(0, 1) \\
\sigma & \sim Exponential(1)
\end{split}\end{split}\]
4M5¶
\[\begin{split}\begin{split}
height & \sim Normal(\mu, \sigma) \\
\mu & = \alpha + \beta \cdot (year - \text{start year}) \\
\alpha & \sim Normal(170, 29) \\
\beta & \sim LogNormal(0, 1) \\
\sigma & \sim Exponential(1)
\end{split}\end{split}\]
4M6¶
We don’t want to change the prior by peeking at the data: priors should encode knowledge held before seeing the sample.
Hard¶
4H1¶
[102]:
# 4H1: quadratic model of height on weight z-scores for the full Howell data.
df = pd.read_csv("../data/Howell1.csv", sep=";")
df["weight_z_score"] = (df["weight"] - df["weight"].mean()) / df["weight"].std()


def h1(
    weight_z_score,
    *,
    alpha_prior={"loc": 170, "scale": 20},
    beta1_prior={"loc": 0, "scale": 1},
    beta2_prior={"loc": 0, "scale": 1},
    sigma_prior={"rate": 1 / 10},
    height=None,
):
    """height ~ Normal(alpha + beta1*z + beta2*z^2, sigma) on weight z-scores."""
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.Normal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    mu = numpyro.deterministic(
        "mu", alpha + beta1 * weight_z_score + beta2 * weight_z_score**2
    )
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    return numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
[103]:
# Visualize 100 draws from the prior predictive distribution of height
# across the observed range of weights.
weights = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
prior_predictive_samples = numpyro.infer.Predictive(h1, num_samples=10_000)(
    jrng, weight_z_score=weight_z_score, height=None
)
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
sampled_curves = height_prior_predictive_samples.sample(n=100, random_state=seed)
for _, curve in sampled_curves.iterrows():
    plt.plot(weights, curve, "k", alpha=0.2)
# Reference lines: zero height and 272 cm (roughly the tallest recorded human).
plt.axhline(y=0, c="k", ls="--")
plt.axhline(y=272, c="k", ls="-", lw=0.5)
plt.title("Prior Predictive Height Samples")
[103]:
Text(0.5, 1.0, 'Prior Predictive Height Samples')
[104]:
# Summarize the prior predictive: mean of mu plus 89% HPDI bands for both
# the expected height (mu) and the simulated heights.
mu_mean = pd.DataFrame(prior_predictive_samples["mu"].mean(axis=0), index=weights)
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(prior_predictive_samples["mu"], prob=0.89),
    columns=weights,
    index=["low", "high"],
)
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(prior_predictive_samples["height"], prob=0.89),
    columns=weights,
    index=["low", "high"],
)
plt.plot(weights, mu_mean, "k")
# Darker band: uncertainty in mu; lighter band: full predictive uncertainty.
plt.fill_between(weights, mu_hpdi.loc["low"], mu_hpdi.loc["high"], color="k", alpha=0.5)
plt.fill_between(
    weights, height_hpdi.loc["low"], height_hpdi.loc["high"], color="k", alpha=0.2
)
[104]:
<matplotlib.collections.PolyCollection at 0x7d8701bc4fb0>
[105]:
# Fit h1 with a Laplace (quadratic) approximation of the posterior via SVI.
guide = AutoLaplaceApproximation(h1)
# Bug fix: the original split `_jrng`, which is undefined at this point on a
# fresh kernel (NameError). Derive the working key from the base key `jrng`.
_, _jrng = jax.random.split(jrng)
svi = SVI(
    model=h1,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    weight_z_score=df["weight_z_score"].values,
    height=df["height"].values,
).run(jrng, 10_000)
100%|█████████████| 10000/10000 [00:02<00:00, 4077.02it/s, init loss: 2903255.0000, avg. loss [9501-10000]: 1981.9767]
[106]:
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
# Evaluate the fitted model on an evenly spaced grid over the observed weights.
weights = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight_z_score=weight_z_score,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(
    posterior_predictive_samples["mu"], columns=weights
)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive_samples["height"], columns=weights
)
mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
# Bug fix: `weight_seq` was undefined (NameError); the intended column index
# is the `weights` grid defined above.
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(mu_posterior_predictive, prob=0.89), columns=weights
)
height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(height_posterior_predictive, prob=0.89), columns=weights
)
# Scatter the raw data, then overlay the posterior mean of mu, its 89% HPDI
# (dark band), and the 89% HPDI of simulated heights (light band).
df[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib"
)
plt.plot(weights, mu_mean.loc[0, :], "k")
plt.fill_between(weights, mu_hpdi.loc[0, :], mu_hpdi.loc[1, :], color="k", alpha=0.5)
plt.fill_between(
    weights, height_hpdi.loc[0, :], height_hpdi.loc[1, :], color="k", alpha=0.2
)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
[107]:
# Posterior predictions for the five individuals from the problem statement.
weights = jnp.array([46.95, 43.72, 64.68, 32.59, 54.63])
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight_z_score=weight_z_score,
    height=None,
)
mu_posterior_predictive_samples = pd.DataFrame(posterior_predictive_samples["mu"])
height_posterior_predictive_samples = pd.DataFrame(
    posterior_predictive_samples["height"]
)
mu_mean = mu_posterior_predictive_samples.mean(axis=0).to_frame().T
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(mu_posterior_predictive_samples, prob=0.89),
)
height_mean = height_posterior_predictive_samples.mean(axis=0).to_frame().T
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(height_posterior_predictive_samples, prob=0.89),
)
# Assemble one row per individual: weight, posterior-mean height, 89% HPDI.
summary_parts = [
    pd.DataFrame(weights, columns=["weight"]),
    height_mean.T.rename(columns={0: "expected height"}),
    height_hpdi.T.rename(columns={0: "89% hpdi low", 1: "89% hpdi high"}),
]
pd.concat(summary_parts, axis=1)
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
[107]:
| weight | expected height | 89% hpdi low | 89% hpdi high | |
|---|---|---|---|---|
| 0 | 46.950001 | 157.183014 | 147.768356 | 166.736542 |
| 1 | 43.720001 | 155.287338 | 145.906158 | 164.688492 |
| 2 | 64.680000 | 152.824203 | 142.585907 | 162.118515 |
| 3 | 32.590000 | 142.363800 | 132.973160 | 151.893936 |
| 4 | 54.630001 | 158.352539 | 148.645493 | 167.660065 |
4H2¶
[108]:
# 4H2: restrict the Howell data to children (age < 18).
df = pd.read_csv("../data/Howell1.csv", sep=";")
df = df[df["age"] < 18]
df
[108]:
| height | weight | age | male | |
|---|---|---|---|---|
| 18 | 121.920 | 19.617854 | 12.0 | 1 |
| 19 | 105.410 | 13.947954 | 8.0 | 0 |
| 20 | 86.360 | 10.489315 | 6.5 | 0 |
| 23 | 129.540 | 23.586784 | 13.0 | 1 |
| 24 | 109.220 | 15.989118 | 7.0 | 0 |
| ... | ... | ... | ... | ... |
| 535 | 114.935 | 17.519991 | 7.0 | 1 |
| 536 | 67.945 | 7.229122 | 1.0 | 0 |
| 538 | 76.835 | 8.022908 | 1.0 | 1 |
| 539 | 145.415 | 31.127751 | 17.0 | 1 |
| 542 | 71.120 | 8.051258 | 0.0 | 1 |
192 rows × 4 columns
[109]:
def h2(
    weight,
    *,
    alpha_prior=None,
    beta_prior=None,
    sigma_prior=None,
    height=None,
):
    """Linear regression of height on weight for the under-18 subset (4H2).

    mu = alpha + beta * weight with a LogNormal prior on beta, constraining
    the slope to be positive. Pass ``height`` to condition on observed data,
    or ``None`` to sample from the prior. Returns the sampled/observed heights
    (the original returned None; consistent now with h1 and h3).
    """
    # None-defaults avoid the mutable-default-argument pitfall; the effective
    # values match the original defaults exactly.
    alpha_prior = {"loc": 100, "scale": 20} if alpha_prior is None else alpha_prior
    beta_prior = {"loc": 0, "scale": 1} if beta_prior is None else beta_prior
    sigma_prior = {"rate": 1 / 10} if sigma_prior is None else sigma_prior
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    mu = numpyro.deterministic("mu", alpha + beta * weight)
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    return numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
[110]:
# Visualize 100 draws from the h2 prior predictive over the weight range.
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
prior_predictive_samples = numpyro.infer.Predictive(h2, num_samples=10_000)(
    jrng, weight=weight, height=None
)
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
sampled_curves = height_prior_predictive_samples.sample(n=100, random_state=seed)
for _, curve in sampled_curves.iterrows():
    plt.plot(weight, curve, "k", alpha=0.2)
plt.axhline(y=0, c="k", ls="--")  # impossible height
plt.axhline(y=272, c="k", ls="-", lw=0.5)  # tallest recorded human
plt.title("Prior Predictive Height Samples")
[110]:
Text(0.5, 1.0, 'Prior Predictive Height Samples')
[111]:
# Fit h2 with a Laplace (quadratic) approximation of the posterior via SVI.
guide = AutoLaplaceApproximation(h2)
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=h2,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    weight=df["weight"].values,
    height=df["height"].values,
    # NOTE(review): `.run` below uses the base key `jrng`, so the `_jrng`
    # split above is not consumed here — confirm this key reuse is intended.
).run(jrng, 10_000)
100%|███████████████| 10000/10000 [00:02<00:00, 3982.20it/s, init loss: 232238.2500, avg. loss [9501-10000]: 691.3686]
[112]:
# Sample from the Laplace posterior and summarize the free parameters,
# excluding the deterministic site "mu" from the table.
_, _jrng = jax.random.split(_jrng)
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
_posterior_samples = {
    site: draws for site, draws in posterior_samples.items() if site != "mu"
}
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat
alpha 58.81 1.26 58.80 56.72 60.74 9597.75 1.00
beta 2.83 0.06 2.83 2.73 2.93 9858.46 1.00
sigma 8.29 0.44 8.28 7.59 8.98 9915.53 1.00
For every 10 kg increase in weight, we expect about a 28 cm increase in height.
[113]:
# Posterior predictive fit of h2 over the weight range: posterior mean of mu
# with 89% HPDI bands for mu (dark) and for simulated heights (light).
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight=weight,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(
    posterior_predictive_samples["mu"], columns=weight
)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive_samples["height"], columns=weight
)
mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(mu_posterior_predictive, prob=0.89), columns=weight
)
height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(height_posterior_predictive, prob=0.89), columns=weight
)
df[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib"
)
plt.plot(weight, mu_mean.iloc[0], "k")
plt.fill_between(weight, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.5)
plt.fill_between(
    weight, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
The data appears to have curvature that’s not captured by the linear model.
4H3¶
[114]:
df = pd.read_csv("../data/Howell1.csv", sep=";")


def h3(
    weight,
    *,
    alpha_prior=None,
    beta_prior=None,
    sigma_prior=None,
    height=None,
):
    """Log-weight regression of height (4H3): mu = alpha + beta * log(weight).

    LogNormal prior on beta keeps the slope positive; Exponential prior on
    sigma. Pass ``height`` to condition on observed data, or ``None`` for
    prior sampling. Returns the sampled/observed heights.
    """
    # None-defaults replace mutable dict defaults; effective values unchanged.
    # (Removed the dead commented-out linear-mu alternative from the original.)
    alpha_prior = {"loc": 178, "scale": 20} if alpha_prior is None else alpha_prior
    beta_prior = {"loc": 0, "scale": 1} if beta_prior is None else beta_prior
    sigma_prior = {"rate": 1 / 10} if sigma_prior is None else sigma_prior
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta = numpyro.sample("beta", dist.LogNormal(**beta_prior))
    mu = numpyro.deterministic("mu", alpha + beta * jnp.log(weight))
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    height = numpyro.sample("height", dist.Normal(loc=mu, scale=sigma), obs=height)
    return height
[115]:
# Visualize 100 draws of the expected height mu from the h3 prior predictive.
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
prior_predictive_samples = numpyro.infer.Predictive(h3, num_samples=10_000)(
    jrng, weight=weight, height=None
)
mu_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["mu"])
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
sampled_curves = mu_prior_predictive_samples.sample(n=100, random_state=seed)
for _, curve in sampled_curves.iterrows():
    plt.plot(weight, curve, "k", alpha=0.2)
plt.axhline(y=0, c="k", ls="--")  # impossible height
plt.axhline(y=272, c="k", ls="-", lw=0.5)  # tallest recorded human
plt.title("Prior Predictive Expected Height Samples")
[115]:
Text(0.5, 1.0, 'Prior Predictive Expected Height Samples')
[116]:
# Fit h3 with a Laplace (quadratic) approximation of the posterior via SVI.
guide = AutoLaplaceApproximation(h3)
_, _jrng = jax.random.split(_jrng)
svi = SVI(
    model=h3,
    guide=guide,
    optim=numpyro.optim.Adam(step_size=0.5),
    loss=Trace_ELBO(),
    weight=df["weight"].values,
    height=df["height"].values,
    # NOTE(review): `.run` uses the base key `jrng` (as in the earlier fits),
    # so the `_jrng` split above is unused here — confirm intended.
).run(jrng, 10_000)
100%|█████████████| 10000/10000 [00:02<00:00, 3923.97it/s, init loss: 1202857.2500, avg. loss [9501-10000]: 2229.7351]
[117]:
_, _jrng = jax.random.split(_jrng)
# Draw posterior samples from the Laplace approximation; exclude the
# deterministic site "mu" so the summary lists only the free parameters.
posterior_samples = guide.sample_posterior(_jrng, svi.params, sample_shape=(10_000,))
_posterior_samples = {k: v for k, v in posterior_samples.items() if k != "mu"}
numpyro.diagnostics.print_summary(_posterior_samples, prob=0.89, group_by_chain=False)
mean std median 5.5% 94.5% n_eff r_hat
alpha -18.85 1.38 -18.84 -21.15 -16.72 9436.35 1.00
beta 45.72 0.39 45.72 45.11 46.37 9430.39 1.00
sigma 5.24 0.16 5.24 4.98 5.50 9364.90 1.00
[118]:
# Posterior predictive fit of h3 over the weight range: posterior mean of mu
# with 89% HPDI bands for mu (dark) and for simulated heights (light).
weight = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
posterior_predictive = numpyro.infer.Predictive(
    guide.model, posterior_samples, return_sites=["mu", "height"]
)
posterior_predictive_samples = posterior_predictive(
    _jrng,
    weight=weight,
    height=None,
)
mu_posterior_predictive = pd.DataFrame(
    posterior_predictive_samples["mu"], columns=weight
)
height_posterior_predictive = pd.DataFrame(
    posterior_predictive_samples["height"], columns=weight
)
mu_mean = mu_posterior_predictive.mean(axis=0).to_frame().T
mu_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(mu_posterior_predictive, prob=0.89), columns=weight
)
height_mean = height_posterior_predictive.mean(axis=0).to_frame().T
height_hpdi = pd.DataFrame(
    numpyro.diagnostics.hpdi(height_posterior_predictive, prob=0.89), columns=weight
)
df[["weight", "height"]].plot(
    kind="scatter", x="weight", y="height", backend="matplotlib"
)
plt.plot(weight, mu_mean.iloc[0], "k")
plt.fill_between(weight, mu_hpdi.iloc[0], mu_hpdi.iloc[1], color="k", alpha=0.5)
plt.fill_between(
    weight, height_hpdi.iloc[0], height_hpdi.iloc[1], color="k", alpha=0.2
)
plt.show()
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
/home/ltiako/.local/share/hatch/env/virtual/rethinking/ANFAH7h_/rethinking/lib/python3.12/site-packages/numpy/core/fromnumeric.py:59: FutureWarning:
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
4H4¶
[119]:
df = pd.read_csv("../data/Howell1.csv", sep=";")
df["weight_z_score"] = (df["weight"] - df["weight"].mean()) / df["weight"].std()


def h4(
    weight_z_score,
    *,
    alpha_prior=None,
    beta1_prior=None,
    beta2_prior=None,
    sigma_prior=None,
    height=None,
):
    """Quadratic regression of height on standardized weight (4H4).

    mu = alpha + beta1 * z + beta2 * z**2 with Normal priors on the
    coefficients and an Exponential prior on sigma. Pass ``height`` to
    condition on observed data, ``None`` for prior sampling.

    NOTE(review): this is the same model as ``h1`` from 4H1.
    """
    # None-defaults instead of mutable dict defaults (Python idiom);
    # effective values are unchanged for callers using the defaults.
    alpha_prior = {"loc": 170, "scale": 20} if alpha_prior is None else alpha_prior
    beta1_prior = {"loc": 0, "scale": 1} if beta1_prior is None else beta1_prior
    beta2_prior = {"loc": 0, "scale": 1} if beta2_prior is None else beta2_prior
    sigma_prior = {"rate": 1 / 10} if sigma_prior is None else sigma_prior
    alpha = numpyro.sample("alpha", dist.Normal(**alpha_prior))
    beta1 = numpyro.sample("beta1", dist.Normal(**beta1_prior))
    beta2 = numpyro.sample("beta2", dist.Normal(**beta2_prior))
    mu = numpyro.deterministic(
        "mu", alpha + beta1 * weight_z_score + beta2 * weight_z_score**2
    )
    sigma = numpyro.sample("sigma", dist.Exponential(**sigma_prior))
    height_forecast = numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
    return height_forecast
[120]:
# Visualize 100 draws from the h4 prior predictive distribution of height
# across the observed range of weights.
weights = jnp.linspace(df["weight"].min(), df["weight"].max(), 50)
weight_z_score = (weights - df["weight"].mean()) / df["weight"].std()
prior_predictive_samples = numpyro.infer.Predictive(h4, num_samples=10_000)(
    jrng, weight_z_score=weight_z_score, height=None
)
height_prior_predictive_samples = pd.DataFrame(prior_predictive_samples["height"])
sampled_curves = height_prior_predictive_samples.sample(n=100, random_state=seed)
for _, curve in sampled_curves.iterrows():
    plt.plot(weights, curve, "k", alpha=0.2)
plt.axhline(y=0, c="k", ls="--")  # impossible height
plt.axhline(y=272, c="k", ls="-", lw=0.5)  # tallest recorded human
plt.title("Prior Predictive Height Samples")
[120]:
Text(0.5, 1.0, 'Prior Predictive Height Samples')
[ ]: